{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.8394994 , -0.5433606 ,  0.31996232], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Environment wrapper: adapts the new gym API (reset -> (state, info),\n",
    "# step -> 5-tuple) back to the classic 4-tuple interface, and caps\n",
    "# every episode at 200 steps.\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0  # steps taken in the current episode\n",
    "\n",
    "    def reset(self):\n",
    "        # Drop the info dict so reset() returns only the state.\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        # Hard cap: force the episode to end after 200 steps.\n",
    "        if self.step_n >= 200:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAk50lEQVR4nO3df3DU9YH/8ddnk+wSEnaTQLJLDlLw4CumCKegsNofZ0mJGC0UnHoOWk45Pb3giXSck55i27uZMHp3tXoW765Tca4KHTzxB4oaA4aqATEQRdRoWzRR3IRf2U2Q/Nz39w+HravRJvBJ9p3k+ZjZGfP5vPPOez9N98l+9rMbxxhjBACAhTypXgAAAF+GSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArJWySN1///2aNGmSRo0apTlz5ujVV19N1VIAAJZKSaR++9vfatWqVbrzzju1Z88ezZw5U6WlpWpubk7FcgAAlnJS8QGzc+bM0Xnnnaf//M//lCTF43FNnDhRN910k2677bbBXg4AwFLpg/0DOzs7VVtbq9WrVye2eTwelZSUqKamptfv6ejoUEdHR+LreDyuo0ePauzYsXIcZ8DXDABwlzFGra2tKiwslMfz5Sf1Bj1Shw8fVk9Pj4LBYNL2YDCod955p9fvqaio0E9/+tPBWB4AYBA1NjZqwoQJX7p/0CN1KlavXq1Vq1Ylvo5GoyoqKlJjY6P8fn8KVwYAOBWxWEwTJ07UmDFjvnLcoEdq3LhxSktLU1NTU9L2pqYmhUKhXr/H5/PJ5/N9Ybvf7ydSADCE/bmXbAb96j6v16tZs2apqqoqsS0ej6uqqkrhcHiwlwMAsFhKTvetWrVKy5Yt0+zZs3X++efrnnvu0fHjx3XNNdekYjkAAEulJFJXXHGFDh06pDVr1igSieiv/uqv9Oyzz37hYgoAwMiWkvdJna5YLKZAIKBoNMprUgAwBPX1cZzP7gMAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgrX5HaseOHbrssstUWFgox3H0+OOPJ+03xmjNmjUaP368MjMzVVJSovfeey9pzNGjR7V06VL5/X7l5ORo+fLlamtrO607AgAYfvodqePHj2vmzJm6//77e91/11136d5779UDDzygXbt2KSsrS6WlpWpvb0+MWbp0qfbv36/Kykpt2bJFO3bs0PXXX3/q9wIAMDyZ0yDJbN68OfF1PB43oVDI3H333YltLS0txufzmQ0bNhhjjHnrrbeMJLN79+7EmK1btxrHccxHH33Up58bjUaNJBONRk9n+QCAFOnr47irr0kdOHBAkUhEJSUliW2BQEBz5sxRTU2NJKmmpkY5OTmaPXt2YkxJSYk8Ho927drV67wdHR2KxWJJNwDA8OdqpCKRiCQpGAwmbQ8Gg4l9kUhEBQUFSfvT09OVl5eXGPN5FRUVCgQCidvEiRPdXDYAwFJD4uq+1atXKxqNJm6NjY2pXhIAYBC4GqlQKCRJampqStre1NSU2BcKhdTc3Jy0v7u7W0ePHk2M+Tyfzye/3590AwAMf65GavLkyQqFQqqqqkpsi8Vi2rVrl8Lh
sCQpHA6rpaVFtbW1iTHbtm1TPB7XnDlz3FwOAGCIS+/vN7S1ten3v/994usDBw6orq5OeXl5Kioq0sqVK/Wv//qvmjp1qiZPnqw77rhDhYWFWrRokSTprLPO0sUXX6zrrrtODzzwgLq6urRixQr9zd/8jQoLC127YwCAYaC/lw1u377dSPrCbdmyZcaYTy9Dv+OOO0wwGDQ+n8/MmzfP1NfXJ81x5MgRc+WVV5rs7Gzj9/vNNddcY1pbW12/dBEAYKe+Po47xhiTwkaeklgspkAgoGg0yutTADAE9fVxfEhc3QcAGJmIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1upXpCoqKnTeeedpzJgxKigo0KJFi1RfX580pr29XeXl5Ro7dqyys7O1ZMkSNTU1JY1paGhQWVmZRo8erYKCAt16663q7u4+/XsDABhW+hWp6upqlZeXa+fOnaqsrFRXV5fmz5+v48ePJ8bccssteuqpp7Rp0yZVV1fr4MGDWrx4cWJ/T0+PysrK1NnZqVdeeUUPPfSQ1q9frzVr1rh3rwAAw4M5Dc3NzUaSqa6uNsYY09LSYjIyMsymTZsSY95++20jydTU1BhjjHnmmWeMx+MxkUgkMWbdunXG7/ebjo6OPv3caDRqJJloNHo6ywcApEhfH8dP6zWpaDQqScrLy5Mk1dbWqqurSyUlJYkx06ZNU1FRkWpqaiRJNTU1OvvssxUMBhNjSktLFYvFtH///l5/TkdHh2KxWNINADD8nXKk4vG4Vq5cqQsvvFDTp0+XJEUiEXm9XuXk5CSNDQaDikQiiTGfDdTJ/Sf39aaiokKBQCBxmzhx4qkuGwAwhJxypMrLy/Xmm29q48aNbq6nV6tXr1Y0Gk3cGhsbB/xnAgBSL/1UvmnFihXasmWLduzYoQkTJiS2h0IhdXZ2qqWlJenZVFNTk0KhUGLMq6++mjTfyav/To75PJ/PJ5/PdypLBQAMYf16JmWM0YoVK7R582Zt27ZNkydPTto/a9YsZWRkqKqqKrGtvr5eDQ0NCofDkqRwOKx9+/apubk5MaayslJ+v1/FxcWnc18AAMNMv55JlZeX65FHHtETTzyhMWPGJF5DCgQCyszMVCAQ0PLly7Vq1Srl5eXJ7/frpptuUjgc1ty5cyVJ8+fPV3Fxsa6++mrdddddikQiuv3221VeXs6zJQBAEscYY/o82HF63f7ggw/qb//2byV9+mbeH/3oR9qwYYM6OjpUWlqqX/7yl0mn8j744APdeOONevHFF5WVlaVly5Zp7dq1Sk/vWzNjsZgCgYCi0aj8fn9flw8AsERfH8f7FSlbECkAGNr6+jjOZ/cBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrpad6AQD+xBjzpfscxxnElQB2IFKABUxPj7pbWxXbs0ctu3ervbFRPSdOKN3vV9aUKcr9xjc0+i//UmlZWcQKIwqRAlIs3tGhlp071fTUU/rkvfekzzyb6jp0SCf+8Acd2b5dgXPPVcGiRco+6yxChRGDSAEpZIzRoeef
V2TTJnW3tHz5uM5OtezcqfaPP1bR9dcre/p0QoURgQsngBQxPT068sILOvjww18ZqM9q/+ADNfz3f6vt7be/8vUrYLggUkCKHH/3XUU2bVL8k0+Stn90/Li2NDZqwx//qBcOHtTxrq6k/e0ffKCPN2xQT1vbYC4XSAlO9wEpEO/qUvS119QRiSS2GWN0oK1Nd+7dq/fb2tTe0yN/Roam5+bq3847TxmeP/2bsvX119Wya5fGzpvHaT8MazyTAlKg68gRNT32WNK2P7a16bqXX9bb0ahO9PTISIp2denl5mbdvGuXjrS3J40/+Mgj6mxqGsRVA4OPSAEpYIyR6elJ2nbP/v2Kfu7U3kmvHj6syoMHk7Z1HT6sj/73f9XT0TFg6wRSjUgBQ1hs71611tVxEQWGLSIFDGE9bW3649q16jx8ONVLAQYEkQIsUTZxojK+5CKISdnZmpGX1+s+09OjQ1u2DOTSgJQhUkAKZAQCyrvooqRtpYWFuvOcczQqLS3xf8w0x9FYn0//ft55Ks7J+dL5jmzfrlbeO4VhiEvQgRTwZGYqNxxW9LXX1NPaKunTD5AtLSzUhNGjteXDD3WkvV2TsrN1xeTJGuvzfeV83S0t+njDBk36x3+Ud9y4wbgLwKAgUkAKOI4jz6hR8ni96vnc9um5uZqem9vvOVvr6nR0xw4FFy6Uk5bm3mKBFOJ0H5AiGbm5Svf7XZ0z8n//pw7eO4VhhEgBKTKqqEh5f/3XctLdO6HR09qqP1RUqPv4cdfmBFKJSAEp4jiOgt/7nnIuuMDVeTubm3Xs5ZddnRNIFSIFpJLHo9D3v6+00aNdmzJ+4oQ+/PWv1bpvH1f7YcgjUkAKOY6jzMmTNeHv/s7dUH3yiQ49++wXPnoJGGqIFJBijsejnLlzlTl5sqvzHnvpJUV++1ueTWFII1KABdKzs/W1FSuUcQqXnn8pY3TslVfU8dFH7s0JDDIiBVjCV1io0BVXuDpne2Ojmh5/XD0nTrg6LzBYiBRgCcdxlBsOyz97tqvzHqup0fH6ek77YUgiUoBF0nNylHP++fKMGuXanD2trTrwb/+m7ljMtTmBwUKkAIs4jqP8iy9WcPFiycU/C9/d1qbmJ5/kaj8MOUQKsFDBJZfI4/W6N2E8rkNbt+rYyy9z2g9DCpECLJQ2Zoym/OQnrn62X09bmw6/8IK6W1oIFYYMIgVYyHEcZU2dqtxvfMPVeVvr6hR59FFX5wQGEpECLOXxehVaskSZkya5Ou/hF17QiQMHXJ0TGChECrCYNz9fRTfe6OpFFPETJ9T4P/+jzkOHXJsTGChECrDc6ClTNLakxNU52/bv1+GqKsW7u12dF3AbkQIs58nI0Pgf/EDZX/+6q/Me2rJF7R9+yEUUsBqRAoYAb0GB8i66SJ7MTNfm7I7F9Ps771S8s9O1OQG3ESlgCHAcR+NKSjTO5dN+3ceP60hlpatzAm4iUsAQ4Xg8Kli4UGlZWa7NaTo7dfiFF/TJ++9z2g9WIlLAEOLNz9ekW25x9U2+J/74R0UefVTx9nbX5gTcQqSAIcRxHI2ZPl3Z06e7Ou+xHTv4c/OwEpEChpi00aM1Yfly+f7iL1yd96P169V15IircwKni0gBQ5AvP1+FV13l6pztH36oD9ev5w8kwipEChiiAueco9xvftPVOVtff13H332X036wBpEChihPZqZyL7hAaaNHuzZndzSqP1RUqOvoUdfmBE4HkQKGKMdxlHvhhSq8+mo5GRmuzRv/5BM1P/20a/MBp4NIAUPcuJISpWVnuzrnoS1bdOi55zjth5TrV6TWrVunGTNmyO/3y+/3KxwOa+vWrYn97e3tKi8v19ixY5Wdna0lS5aoqakpaY6GhgaVlZVp9OjRKigo0K233qpuPuQSOGWO16v/9y//ovScHNfmjLe3q+WVV7jaDynXr0hNmDBBa9euVW1trV577TV95zvf0cKFC7V//35J0i233KKnnnpKmzZtUnV1
tQ4ePKjFixcnvr+np0dlZWXq7OzUK6+8ooceekjr16/XmjVr3L1XwAjiOI5GFRYqf8ECV+eN7d2rgw8/rHhXl6vzAv3hmNN8Pp+Xl6e7775bl19+ufLz8/XII4/o8ssvlyS98847Ouuss1RTU6O5c+dq69atuvTSS3Xw4EEFg0FJ0gMPPKB/+qd/0qFDh+T1evv0M2OxmAKBgKLRqPwuvvMeGMo6jxzRgX//d7W9+aZ7k3o8Kr7vPmVOnOjenID6/jh+yq9J9fT0aOPGjTp+/LjC4bBqa2vV1dWlks98AOa0adNUVFSkmpoaSVJNTY3OPvvsRKAkqbS0VLFYLPFsrDcdHR2KxWJJNwDJvGPHqujv/16eUaPcmzQe1/v33KPu48fdmxPoh35Hat++fcrOzpbP59MNN9ygzZs3q7i4WJFIRF6vVzmfOy8eDAYViUQkSZFIJClQJ/ef3PdlKioqFAgEEreJ/KsO6NWoiROVf/HFrs7Z/uGHOva738n09Lg6L9AX/Y7UmWeeqbq6Ou3atUs33nijli1bprfeemsg1pawevVqRaPRxK2xsXFAfx4wVDkej4KLF8s/e7Zrc8ZPnNDBhx/WJ7//PVf7YdD1O1Jer1dTpkzRrFmzVFFRoZkzZ+oXv/iFQqGQOjs71dLSkjS+qalJoVBIkhQKhb5wtd/Jr0+O6Y3P50tcUXjyBqB36YGAxn7nO0obM8a1ObujUTU9+aQMV+JikJ32+6Ti8bg6Ojo0a9YsZWRkqKqqKrGvvr5eDQ0NCofDkqRwOKx9+/apubk5MaayslJ+v1/FxcWnuxQA+tObfIMLF0oe994Keezll3WIN/likKX3Z/Dq1au1YMECFRUVqbW1VY888ohefPFFPffccwoEAlq+fLlWrVqlvLw8+f1+3XTTTQqHw5o7d64kaf78+SouLtbVV1+tu+66S5FIRLfffrvKy8vl8/kG5A4CI5HjOMq/5BI1b9mi7s+d3Thl8bgOv/CC/LNmcbUfBk2//pnV3NysH/7whzrzzDM1b9487d69W88995y++93vSpJ+/vOf69JLL9WSJUv0rW99S6FQSI899lji+9PS0rRlyxalpaUpHA7rqquu0g9/+EP97Gc/c/deAVBaVpb+8sc/VkZenmtztjc06OONG9Xd2uranMBXOe33SaUC75MC+ibe0aEPH3xQh555xtV5z7jtNuWEw3Icx9V5MXIM+PukANjP4/Mp9IMfKGvaNFfnbfzVr9w7jQh8BSIFDHMZubn6i6uvdvUiiq7Dh3XgnnsU7+hwbU6gN0QKGOYcx1HWWWdpXGmpq/Oe+MMfFNu7l/dOYUARKWAEcNLSFFy0SNlf/7prc3bHYmp++ml1HT1KqDBgiBQwAjiOo1Hjx2tcaanSsrJcm7f19dfV9OSTrs0HfB6RAkaQ3AsukHfcOFfnPPzcc3xkEgYMkQJGEI/Xqyk/+Ykyxo51bc74J5/o4G9+o65jx1ybEziJSAEjTEZursb/4Aeuzhnbu1eHn3tOcT7bDy4jUsAI43g8ygmHFZgzx9V5Dz/7rLqOHOG0H1xFpIARKCMnRxOuvdbV035dx47pvTVreO8UXEWkgBFq1Pjxyr/kElfn7Dp6VEe3b3d1ToxsRAoYwQouvVS53/qWa/PFOzp08JFHFN2zh9N+cAWRAkYwz6hRGldS4uonpXdHozq6fTun/eAKIgWMYI7jaMzMmRp/xRVyvF7X5j1aXa2mz/yZHuBUESlghHMcR3kXXeT6m3yPbNumjs/8FW7gVBApAPL4fJqyZo0yXAxVZ3OzDv7mN+o5ccK1OTHyECkAchxH3vx8jSspcXXeWF2dWvfv5yIKnDIiBUCS5MnIUP4ll8g/e7Zrc3a3tOiD++5TdzTq2pwYWYgUgIT0QECFS5fKk5np2pzdx46p6YknZHp6XJsTIweRApDgOI5Gn3GGCi67zNV5Dz/3nI69/DKn/dBvRArAFxSUlck/a5Zr8/W0tenItm3q
OnrUtTkxMhApAEkcx1FGbq7yFyxQek6Oa/PG9uxR+4cfujYfRgYiBaBX/nPOUeakSa7O2dncLBOPuzonhjciBaBXnowMnXHrrfKNH+/anB8++KBMV5dr82H4I1IAvlRadrYmXHuta/PF29tdmwsjA5EC8KUcx1H29OnKu+iiVC8FIxSRAvCV0rOyVHjVVa6/PgX0BZEC8Gf58vOVv2CB5OEhA4OL3zgAfTL2u99Vzpw5pzVH/iWXyElPd2lFGAmIFIA+cdLSFFq8+JTfO5Wek6O8b36TZ2PoF35bAPSJ4zjKPOMMhS6/vN+f7efJzFTo8suVecYZchxngFaI4YhIAegzT0aGxpWUKH/Bgj6ftnPS05W/YIHGlZTIk5ExwCvEcMPJYQD9kjZ6tAqXLlVaVpYOP/+8Opubpd4+ONZx5C0o0Lj58xVctIhA4ZQQKQD95snIUHDhQmUXF6tl5061vvGGOpqaFD9xQp7MTPmCQY2ZMUM5c+cqa+pUAoVTRqQAnBKP16sxX/+6sqZMUXdrq+KdnTI9PXLS0uTxepU2ZozSfL5ULxNDHJECcFo8Pp+8xAgDhAsnAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtU4rUmvXrpXjOFq5cmViW3t7u8rLyzV27FhlZ2dryZIlampqSvq+hoYGlZWVafTo0SooKNCtt96q7u7u01kKAGAYOuVI7d69W//1X/+lGTNmJG2/5ZZb9NRTT2nTpk2qrq7WwYMHtXjx4sT+np4elZWVqbOzU6+88ooeeughrV+/XmvWrDn1ewEAGJ7MKWhtbTVTp041lZWV5tvf/ra5+eabjTHGtLS0mIyMDLNp06bE2LfffttIMjU1NcYYY5555hnj8XhMJBJJjFm3bp3x+/2mo6OjTz8/Go0aSSYajZ7K8gEAKdbXx/FTeiZVXl6usrIylZSUJG2vra1VV1dX0vZp06apqKhINTU1kqSamhqdffbZCgaDiTGlpaWKxWLav39/rz+vo6NDsVgs6QYAGP7S+/sNGzdu1J49e7R79+4v7ItEIvJ6vcrJyUnaHgwGFYlEEmM+G6iT+0/u601FRYV++tOf9nepAIAhrl/PpBobG3XzzTfr4Ycf1qhRowZqTV+wevVqRaPRxK2xsXHQfjYAIHX6Fana2lo1Nzfr3HPPVXp6utLT01VdXa17771X6enpCgaD6uzsVEtLS9L3NTU1KRQKSZJCodAXrvY7+fXJMZ/n8/nk9/uTbgCA4a9fkZo3b5727dunurq6xG327NlaunRp4r8zMjJUVVWV+J76+no1NDQoHA5LksLhsPbt26fm5ubEmMrKSvn9fhUXF7t0twAAw0G/XpMaM2aMpk+fnrQtKytLY8eOTWxfvny5Vq1apby8PPn9ft10000Kh8OaO3euJGn+/PkqLi7W1VdfrbvuukuRSES33367ysvL5fP5XLpbAIDhoN8XTvw5P//5z+XxeLRkyRJ1dHSotLRUv/zlLxP709LStGXLFt14440Kh8PKysrSsmXL9LOf/cztpQAAhjjHGGNSvYj+isViCgQCikajvD4FAENQXx/H+ew+AIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIF
ALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC10lO9gFNhjJEkxWKxFK8EAHAqTj5+n3w8/zJDMlJHjhyRJE2cODHFKwEAnI7W1lYFAoEv3T8kI5WXlydJamho+Mo7N9LFYjFNnDhRjY2N8vv9qV6OtThOfcNx6huOU98YY9Ta2qrCwsKvHDckI+XxfPpSWiAQ4JegD/x+P8epDzhOfcNx6huO05/XlycZXDgBALAWkQIAWGtIRsrn8+nOO++Uz+dL9VKsxnHqG45T33Cc+obj5C7H/Lnr/wAASJEh+UwKADAyECkAgLWIFADAWkQKAGCtIRmp+++/X5MmTdKoUaM0Z84cvfrqq6le0qDasWOHLrvsMhUWFspxHD3++ONJ+40xWrNmjcaPH6/MzEyVlJTovffeSxpz9OhRLV26VH6/Xzk5OVq+fLna2toG8V4MrIqKCp133nkaM2aMCgoKtGjRItXX1yeNaW9vV3l5ucaOHavs7GwtWbJETU1NSWMaGhpUVlam0aNHq6CgQLfeequ6u7sH864MqHXr1mnGjBmJN56Gw2Ft3bo1sZ9j1Lu1a9fKcRytXLkysY1jNUDMELNx40bj9XrNr3/9a7N//35z3XXXmZycHNPU1JTqpQ2aZ555xvzzP/+zeeyxx4wks3nz5qT9a9euNYFAwDz++OPm9ddfN9/73vfM5MmTzYkTJxJjLr74YjNz5kyzc+dO87vf/c5MmTLFXHnllYN8TwZOaWmpefDBB82bb75p6urqzCWXXGKKiopMW1tbYswNN9xgJk6caKqqqsxrr71m5s6day644ILE/u7ubjN9+nRTUlJi9u7da5555hkzbtw4s3r16lTcpQHx5JNPmqefftq8++67pr6+3vz4xz82GRkZ5s033zTGcIx68+qrr5pJkyaZGTNmmJtvvjmxnWM1MIZcpM4//3xTXl6e+Lqnp8cUFhaaioqKFK4qdT4fqXg8bkKhkLn77rsT21paWozP5zMbNmwwxhjz1ltvGUlm9+7diTFbt241juOYjz76aNDWPpiam5uNJFNdXW2M+fSYZGRkmE2bNiXGvP3220aSqampMcZ8+o8Bj8djIpFIYsy6deuM3+83HR0dg3sHBlFubq751a9+xTHqRWtrq5k6daqprKw03/72txOR4lgNnCF1uq+zs1O1tbUqKSlJbPN4PCopKVFNTU0KV2aPAwcOKBKJJB2jQCCgOXPmJI5RTU2NcnJyNHv27MSYkpISeTwe7dq1a9DXPBii0aikP304cW1trbq6upKO07Rp01RUVJR0nM4++2wFg8HEmNLSUsViMe3fv38QVz84enp6tHHjRh0/flzhcJhj1Ivy8nKVlZUlHROJ36eBNKQ+YPbw4cPq6elJ+h9ZkoLBoN55550UrcoukUhEkno9Rif3RSIRFRQUJO1PT09XXl5eYsxwEo/HtXLlSl144YWaPn26pE+PgdfrVU5OTtLYzx+n3o7jyX3Dxb59+xQOh9Xe3q7s7Gxt3rxZxcXFqqur4xh9xsaNG7Vnzx7t3r37C/v4fRo4QypSwKkoLy/Xm2++qZdeeinVS7HSmWeeqbq6OkWjUT366KNatmyZqqurU70sqzQ2Nurmm29WZWWlRo0alerljChD6nTfuHHjlJaW9oUrZpqamhQKhVK0KrucPA5fdYxCoZCam5uT9nd3d+vo0aPD7jiuWLFCW7Zs0fbt2zVhwoTE9lAopM7OTrW0tCSN//xx6u04ntw3XHi9Xk2ZMkWzZs1SRUWFZs6cqV/84hcco8+ora1Vc3Ozzj33XKWnpys9PV3V1dW69957lZ6ermAwyLEaIEMqUl6vV7NmzVJVVVViWzweV1VVlcLhcApXZo/JkycrFAolHaNYLKZdu3YljlE4HFZL
S4tqa2sTY7Zt26Z4PK45c+YM+poHgjFGK1as0ObNm7Vt2zZNnjw5af+sWbOUkZGRdJzq6+vV0NCQdJz27duXFPTKykr5/X4VFxcPzh1JgXg8ro6ODo7RZ8ybN0/79u1TXV1d4jZ79mwtXbo08d8cqwGS6is3+mvjxo3G5/OZ9evXm7feestcf/31JicnJ+mKmeGutbXV7N271+zdu9dIMv/xH/9h9u7daz744ANjzKeXoOfk5JgnnnjCvPHGG2bhwoW9XoJ+zjnnmF27dpmXXnrJTJ06dVhdgn7jjTeaQCBgXnzxRfPxxx8nbp988klizA033GCKiorMtm3bzGuvvWbC4bAJh8OJ/ScvGZ4/f76pq6szzz77rMnPzx9Wlwzfdtttprq62hw4cMC88cYb5rbbbjOO45jnn3/eGMMx+iqfvbrPGI7VQBlykTLGmPvuu88UFRUZr9drzj//fLNz585UL2lQbd++3Uj6wm3ZsmXGmE8vQ7/jjjtMMBg0Pp/PzJs3z9TX1yfNceTIEXPllVea7Oxs4/f7zTXXXGNaW1tTcG8GRm/HR5J58MEHE2NOnDhh/uEf/sHk5uaa0aNHm+9///vm448/Tprn/fffNwsWLDCZmZlm3Lhx5kc/+pHp6uoa5HszcK699lrzta99zXi9XpOfn2/mzZuXCJQxHKOv8vlIcawGBn+qAwBgrSH1mhQAYGQhUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFr/H0hjbs3HjO6ZAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# Render the current game frame inline in the notebook.\n",
    "def show():\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-1.4391541481018066, -1374.6978214236847)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import random\n",
    "from IPython import display\n",
    "import math\n",
    "\n",
    "\n",
    "class SAC:\n",
    "    # Gaussian policy network: state -> (squashed action, entropy term).\n",
    "    class ModelAction(torch.nn.Module):\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.fc_state = torch.nn.Sequential(\n",
    "                torch.nn.Linear(3, 128),\n",
    "                torch.nn.ReLU(),\n",
    "            )\n",
    "            self.fc_mu = torch.nn.Linear(128, 1)\n",
    "            #Softplus keeps the predicted standard deviation positive\n",
    "            self.fc_std = torch.nn.Sequential(\n",
    "                torch.nn.Linear(128, 1),\n",
    "                torch.nn.Softplus(),\n",
    "            )\n",
    "\n",
    "        def forward(self, state):\n",
    "            #[b, 3] -> [b, 128]\n",
    "            state = self.fc_state(state)\n",
    "\n",
    "            #[b, 128] -> [b, 1]\n",
    "            mu = self.fc_mu(state)\n",
    "\n",
    "            #[b, 128] -> [b, 1]\n",
    "            std = self.fc_std(state)\n",
    "\n",
    "            #define b normal distributions from mu and std\n",
    "            dist = torch.distributions.Normal(mu, std)\n",
    "\n",
    "            #draw b samples\n",
    "            #rsample is the reparameterized sample: draw from a standard\n",
    "            #normal, scale by std and shift by mu, so gradients can flow\n",
    "            sample = dist.rsample()\n",
    "\n",
    "            #squash the sample into (-1, 1) to get the action\n",
    "            action = torch.tanh(sample)\n",
    "\n",
    "            #log-probability of the raw (unsquashed) sample\n",
    "            log_prob = dist.log_prob(sample)\n",
    "\n",
    "            #tanh change-of-variables correction: the log-prob of the\n",
    "            #squashed action is log_prob - log(1 - tanh(sample)^2 + eps).\n",
    "            #`action` already equals tanh(sample), so square it directly\n",
    "            #(the original code applied tanh a second time here, which is\n",
    "            #wrong). Negating gives the per-sample entropy estimate.\n",
    "            entropy = log_prob - (1 - action**2 + 1e-7).log()\n",
    "            entropy = -entropy\n",
    "\n",
    "            #scale to the Pendulum action range [-2, 2]\n",
    "            return action * 2, entropy\n",
    "\n",
    "    # Q-network: (state, action) -> scalar value estimate.\n",
    "    class ModelValue(torch.nn.Module):\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.sequential = torch.nn.Sequential(\n",
    "                torch.nn.Linear(4, 128),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.Linear(128, 128),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.Linear(128, 1),\n",
    "            )\n",
    "\n",
    "        def forward(self, state, action):\n",
    "            #[b, 3]+[b, 1] -> [b, 4]\n",
    "            state = torch.cat([state, action], dim=1)\n",
    "\n",
    "            #[b, 4] -> [b, 1]\n",
    "            return self.sequential(state)\n",
    "\n",
    "    def __init__(self):\n",
    "        self.model_action = self.ModelAction().to(device='cuda')\n",
    "\n",
    "        #twin Q-networks (clipped double-Q)\n",
    "        self.model_value1 = self.ModelValue().to(device='cuda')\n",
    "        self.model_value2 = self.ModelValue().to(device='cuda')\n",
    "\n",
    "        #target Q-networks, updated softly toward the online ones\n",
    "        self.model_value_next1 = self.ModelValue().to(device='cuda')\n",
    "        self.model_value_next2 = self.ModelValue().to(device='cuda')\n",
    "\n",
    "        self.model_value_next1.load_state_dict(self.model_value1.state_dict())\n",
    "        self.model_value_next2.load_state_dict(self.model_value2.state_dict())\n",
    "\n",
    "        #log of the entropy temperature; this is a learnable parameter\n",
    "        self.alpha = torch.tensor(math.log(0.01))\n",
    "        self.alpha.requires_grad = True\n",
    "\n",
    "        self.optimizer_action = torch.optim.Adam(\n",
    "            self.model_action.parameters(), lr=3e-4)\n",
    "        self.optimizer_value1 = torch.optim.Adam(\n",
    "            self.model_value1.parameters(), lr=3e-3)\n",
    "        self.optimizer_value2 = torch.optim.Adam(\n",
    "            self.model_value2.parameters(), lr=3e-3)\n",
    "\n",
    "        #alpha is trained too, so it gets its own optimizer\n",
    "        self.optimizer_alpha = torch.optim.Adam([self.alpha], lr=3e-4)\n",
    "\n",
    "        self.loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    def get_action(self, state):\n",
    "        #sample one action for a single state (used when playing)\n",
    "        state = torch.FloatTensor(state).reshape(1, 3).to(device='cuda')\n",
    "        action, _ = self.model_action(state)\n",
    "        return action.item()\n",
    "\n",
    "    def test(self, play):\n",
    "        #reset the game\n",
    "        state = env.reset()\n",
    "\n",
    "        #accumulated reward; higher is better\n",
    "        reward_sum = 0\n",
    "\n",
    "        #play until the episode ends\n",
    "        over = False\n",
    "        while not over:\n",
    "            #pick an action for the current state\n",
    "            action = self.get_action(state)\n",
    "\n",
    "            #execute the action and observe the feedback\n",
    "            state, reward, over, _ = env.step([action])\n",
    "            reward_sum += reward\n",
    "\n",
    "            #render the animation\n",
    "            if play and random.random() < 0.2:  #skip frames\n",
    "                display.clear_output(wait=True)\n",
    "                show()\n",
    "\n",
    "        return reward_sum\n",
    "\n",
    "    def _soft_update(self, model, model_next):\n",
    "        for param, param_next in zip(model.parameters(),\n",
    "                                     model_next.parameters()):\n",
    "            #move the target weights a small step toward the online weights\n",
    "            value = param_next.data * 0.995 + param.data * 0.005\n",
    "            param_next.data.copy_(value)\n",
    "\n",
    "    def _get_target(self, reward, next_state, over):\n",
    "        #use model_action to get the next action and its entropy\n",
    "        #[b, 3] -> [b, 1],[b, 1]\n",
    "        action, entropy = self.model_action(next_state)\n",
    "\n",
    "        #evaluate next_state with the target Q-networks\n",
    "        #[b, 3],[b, 1] -> [b, 1]\n",
    "        target1 = self.model_value_next1(next_state, action)\n",
    "        target2 = self.model_value_next2(next_state, action)\n",
    "\n",
    "        #take the smaller value for stability (clipped double-Q)\n",
    "        #[b, 1]\n",
    "        target = torch.min(target1, target2)\n",
    "\n",
    "        #exp undoes the log, recovering alpha itself;\n",
    "        #add the entropy bonus to the target, weighted by alpha\n",
    "        #[b, 1] + [b, 1] -> [b, 1]\n",
    "        target += self.alpha.exp() * entropy\n",
    "\n",
    "        #discount, zero out terminal states, add the immediate reward\n",
    "        #[b, 1]\n",
    "        target *= 0.99\n",
    "        target *= (1 - over)\n",
    "        target += reward\n",
    "\n",
    "        return target\n",
    "\n",
    "    def _get_loss_action(self, state):\n",
    "        #compute the action and its entropy\n",
    "        #[b, 3] -> [b, 1],[b, 1]\n",
    "        action, entropy = self.model_action(state)\n",
    "\n",
    "        #score the action with both Q-networks\n",
    "        #[b, 3],[b, 1] -> [b, 1]\n",
    "        value1 = self.model_value1(state, action)\n",
    "        value2 = self.model_value2(state, action)\n",
    "\n",
    "        #take the smaller value for stability\n",
    "        #[b, 1]\n",
    "        value = torch.min(value1, value2)\n",
    "\n",
    "        #we want alpha * entropy to be large, so negate it for the loss\n",
    "        #[1] * [b, 1] -> [b, 1]\n",
    "        loss_action = -self.alpha.exp() * entropy\n",
    "\n",
    "        #subtract value: a larger Q-value means a smaller loss\n",
    "        loss_action -= value\n",
    "\n",
    "        return loss_action.mean(), entropy\n",
    "\n",
    "    def _get_loss_value(self, model_value, target, state, action, next_state):\n",
    "        #compute the current value estimate\n",
    "        #(next_state is unused; kept for interface compatibility)\n",
    "        value = model_value(state, action)\n",
    "\n",
    "        #the value should regress toward the target\n",
    "        loss_value = self.loss_fn(value, target)\n",
    "        return loss_value\n",
    "\n",
    "    def train(self, state, action, reward, next_state, over):\n",
    "        #shift and scale the reward to ease training\n",
    "        reward = (reward + 8) / 8\n",
    "\n",
    "        #compute the TD target (already includes the entropy bonus)\n",
    "        #[b, 1]\n",
    "        target = self._get_target(reward, next_state, over)\n",
    "        target = target.detach()\n",
    "\n",
    "        #losses for both Q-networks\n",
    "        loss_value1 = self._get_loss_value(self.model_value1, target, state,\n",
    "                                           action, next_state)\n",
    "        loss_value2 = self._get_loss_value(self.model_value2, target, state,\n",
    "                                           action, next_state)\n",
    "\n",
    "        #update the Q-networks\n",
    "        self.optimizer_value1.zero_grad()\n",
    "        loss_value1.backward()\n",
    "        self.optimizer_value1.step()\n",
    "\n",
    "        self.optimizer_value2.zero_grad()\n",
    "        loss_value2.backward()\n",
    "        self.optimizer_value2.step()\n",
    "\n",
    "        #use the Q-networks to compute the policy loss\n",
    "        loss_action, entropy = self._get_loss_action(state)\n",
    "        self.optimizer_action.zero_grad()\n",
    "        loss_action.backward()\n",
    "        self.optimizer_action.step()\n",
    "\n",
    "        #alpha loss: drives the entropy toward the target entropy (-1)\n",
    "        #[b, 1] -> [1]\n",
    "        loss_alpha = (entropy + 1).detach() * self.alpha.exp()\n",
    "        loss_alpha = loss_alpha.mean()\n",
    "\n",
    "        #update alpha\n",
    "        self.optimizer_alpha.zero_grad()\n",
    "        loss_alpha.backward()\n",
    "        self.optimizer_alpha.step()\n",
    "\n",
    "        #soft-update the target networks\n",
    "        self._soft_update(self.model_value1, self.model_value_next1)\n",
    "        self._soft_update(self.model_value2, self.model_value_next2)\n",
    "\n",
    "\n",
    "teacher = SAC()\n",
    "\n",
    "#smoke-test one training step with random tensors\n",
    "teacher.train(\n",
    "    torch.randn(5, 3).to(device='cuda'),\n",
    "    torch.randn(5, 1).to(device='cuda'),\n",
    "    torch.randn(5, 1).to(device='cuda'),\n",
    "    torch.randn(5, 3).to(device='cuda'),\n",
    "    torch.zeros(5, 1).long().to(device='cuda'),\n",
    ")\n",
    "\n",
    "teacher.get_action([1, 2, 3]), teacher.test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(None,\n",
       " (tensor([[ 0.1692, -0.9856, -2.1773],\n",
       "          [-0.3216,  0.9469,  5.5364],\n",
       "          [ 0.2752, -0.9614, -1.6344],\n",
       "          [ 0.9984,  0.0561, -4.5115],\n",
       "          [-0.3889, -0.9213, -6.8181],\n",
       "          [-0.9904,  0.1385, -8.0000],\n",
       "          [ 0.4599,  0.8880, -4.0348],\n",
       "          [ 0.0470,  0.9989, -7.0242],\n",
       "          [ 0.5866, -0.8098, -5.7006],\n",
       "          [ 0.7217,  0.6923, -3.0505],\n",
       "          [ 0.3528, -0.9357, -1.0516],\n",
       "          [ 0.8315,  0.5555,  1.6290],\n",
       "          [-0.2662,  0.9639, -7.4282],\n",
       "          [ 0.9858, -0.1681, -2.1228],\n",
       "          [-0.2587, -0.9660,  5.0915],\n",
       "          [-0.9246, -0.3810, -8.0000],\n",
       "          [-0.9435,  0.3314, -8.0000],\n",
       "          [-0.9683,  0.2497,  6.9810],\n",
       "          [-0.4879, -0.8729, -8.0000],\n",
       "          [ 0.9485,  0.3169,  0.2374],\n",
       "          [ 0.7758, -0.6310, -3.1942],\n",
       "          [-0.4411,  0.8974, -7.5968],\n",
       "          [-0.8245,  0.5658,  6.6795],\n",
       "          [-0.6508, -0.7593, -8.0000],\n",
       "          [-0.8779,  0.4788, -8.0000],\n",
       "          [-0.6927,  0.7212, -7.9284],\n",
       "          [ 0.9858, -0.1682, -4.5014],\n",
       "          [ 0.7964,  0.6048, -5.2489],\n",
       "          [ 0.0441,  0.9990, -5.2449],\n",
       "          [ 0.8928,  0.4504, -4.9587],\n",
       "          [ 0.9315,  0.3636, -5.1935],\n",
       "          [-0.8951, -0.4459, -8.0000],\n",
       "          [ 0.8740,  0.4860,  1.3497],\n",
       "          [ 0.8758,  0.4827, -2.4405],\n",
       "          [ 0.4480,  0.8940, -6.3327],\n",
       "          [ 0.8662, -0.4996, -2.6321],\n",
       "          [ 0.6122,  0.7907, -5.6816],\n",
       "          [-0.8141, -0.5808, -5.9245],\n",
       "          [ 0.7146,  0.6995, -5.5104],\n",
       "          [ 0.7907, -0.6122, -5.1057],\n",
       "          [ 0.5343, -0.8453, -6.4004],\n",
       "          [ 0.0328, -0.9995, -7.6467],\n",
       "          [-0.2779,  0.9606, -7.4574],\n",
       "          [-0.0700, -0.9975, -7.4105],\n",
       "          [ 0.6862,  0.7274, -5.8337],\n",
       "          [ 0.7550,  0.6557, -5.3748],\n",
       "          [ 0.7731, -0.6343, -5.7047],\n",
       "          [-0.0238,  0.9997, -7.1064],\n",
       "          [-0.1095, -0.9940, -7.8610],\n",
       "          [-0.9757,  0.2191, -8.0000],\n",
       "          [ 0.7688,  0.6395,  2.0980],\n",
       "          [-0.0326,  0.9995, -7.0783],\n",
       "          [ 0.2886, -0.9574, -6.8540],\n",
       "          [ 0.9249,  0.3801,  0.7429],\n",
       "          [ 0.9995,  0.0309, -1.6870],\n",
       "          [ 0.9579,  0.2871, -1.9702],\n",
       "          [ 0.8717, -0.4900, -5.1518],\n",
       "          [ 0.6231,  0.7822, -5.7057],\n",
       "          [-0.8369, -0.5473, -8.0000],\n",
       "          [-0.9924,  0.1233, -8.0000],\n",
       "          [ 0.9115,  0.4113, -5.1935],\n",
       "          [ 0.9934,  0.1150, -1.6915],\n",
       "          [-0.5313,  0.8472, -7.7954],\n",
       "          [ 0.0189, -0.9998, -3.0213]], device='cuda:0'),\n",
       "  tensor([[-0.6986],\n",
       "          [-0.5825],\n",
       "          [ 1.1877],\n",
       "          [-0.2130],\n",
       "          [-1.3190],\n",
       "          [-1.3362],\n",
       "          [-0.9792],\n",
       "          [-1.1877],\n",
       "          [-1.4786],\n",
       "          [-1.6326],\n",
       "          [ 0.7928],\n",
       "          [ 0.3494],\n",
       "          [-1.0608],\n",
       "          [ 1.4811],\n",
       "          [-0.3796],\n",
       "          [-1.3578],\n",
       "          [-1.1614],\n",
       "          [-1.3226],\n",
       "          [-1.2559],\n",
       "          [ 0.8808],\n",
       "          [-0.4055],\n",
       "          [-1.3340],\n",
       "          [-0.8190],\n",
       "          [-1.3556],\n",
       "          [-1.2943],\n",
       "          [-1.2388],\n",
       "          [-0.4293],\n",
       "          [-1.0231],\n",
       "          [-0.9661],\n",
       "          [-1.5272],\n",
       "          [-0.9974],\n",
       "          [-1.2165],\n",
       "          [-0.5678],\n",
       "          [-1.3465],\n",
       "          [-1.1435],\n",
       "          [-1.2490],\n",
       "          [-1.0692],\n",
       "          [-1.0287],\n",
       "          [-1.1807],\n",
       "          [-0.9050],\n",
       "          [-0.4866],\n",
       "          [-1.1394],\n",
       "          [-0.8525],\n",
       "          [-1.2875],\n",
       "          [-0.6639],\n",
       "          [-0.5046],\n",
       "          [-1.4662],\n",
       "          [-1.4582],\n",
       "          [-1.3850],\n",
       "          [-1.5057],\n",
       "          [ 0.0382],\n",
       "          [-1.3836],\n",
       "          [-1.1530],\n",
       "          [-0.1740],\n",
       "          [-1.4036],\n",
       "          [-0.4880],\n",
       "          [-1.1935],\n",
       "          [-0.6652],\n",
       "          [-1.2217],\n",
       "          [-1.1579],\n",
       "          [-1.5774],\n",
       "          [-0.5456],\n",
       "          [-1.2440],\n",
       "          [-0.1492]], device='cuda:0'),\n",
       "  tensor([[ -2.4368],\n",
       "          [ -6.6689],\n",
       "          [ -1.9377],\n",
       "          [ -2.0385],\n",
       "          [ -8.5324],\n",
       "          [-15.4179],\n",
       "          [ -2.8234],\n",
       "          [ -7.2572],\n",
       "          [ -4.1428],\n",
       "          [ -1.5178],\n",
       "          [ -1.5759],\n",
       "          [ -0.6124],\n",
       "          [ -8.9054],\n",
       "          [ -0.4813],\n",
       "          [ -5.9505],\n",
       "          [-13.9684],\n",
       "          [-14.2625],\n",
       "          [-13.2230],\n",
       "          [-10.7300],\n",
       "          [ -0.1104],\n",
       "          [ -1.4868],\n",
       "          [ -9.8843],\n",
       "          [-10.9149],\n",
       "          [-11.5976],\n",
       "          [-13.3834],\n",
       "          [-11.7444],\n",
       "          [ -2.0550],\n",
       "          [ -3.1780],\n",
       "          [ -5.0826],\n",
       "          [ -2.6794],\n",
       "          [ -2.8367],\n",
       "          [-13.5807],\n",
       "          [ -0.4400],\n",
       "          [ -0.8512],\n",
       "          [ -5.2354],\n",
       "          [ -0.9681],\n",
       "          [ -4.0608],\n",
       "          [ -9.8709],\n",
       "          [ -3.6381],\n",
       "          [ -3.0417],\n",
       "          [ -5.1110],\n",
       "          [ -8.2138],\n",
       "          [ -8.9935],\n",
       "          [ -8.1858],\n",
       "          [ -4.0671],\n",
       "          [ -3.4005],\n",
       "          [ -3.7286],\n",
       "          [ -7.5949],\n",
       "          [ -9.0055],\n",
       "          [-14.9329],\n",
       "          [ -0.9216],\n",
       "          [ -7.5832],\n",
       "          [ -6.3323],\n",
       "          [ -0.2073],\n",
       "          [ -0.2875],\n",
       "          [ -0.4732],\n",
       "          [ -2.9177],\n",
       "          [ -4.0626],\n",
       "          [-12.9677],\n",
       "          [-15.5093],\n",
       "          [ -2.8794],\n",
       "          [ -0.2997],\n",
       "          [-10.6193],\n",
       "          [ -3.3212]], device='cuda:0'),\n",
       "  tensor([[ 0.0189, -0.9998, -3.0213],\n",
       "          [-0.5935,  0.8048,  6.1591],\n",
       "          [ 0.1692, -0.9856, -2.1773],\n",
       "          [ 0.9858, -0.1682, -4.5014],\n",
       "          [-0.7067, -0.7075, -7.7069],\n",
       "          [-0.8583,  0.5132, -8.0000],\n",
       "          [ 0.6081,  0.7939, -3.5157],\n",
       "          [ 0.3613,  0.9324, -6.4532],\n",
       "          [ 0.2959, -0.9552, -6.5298],\n",
       "          [ 0.8105,  0.5857, -2.7762],\n",
       "          [ 0.2752, -0.9614, -1.6344],\n",
       "          [ 0.7688,  0.6395,  2.0980],\n",
       "          [ 0.0737,  0.9973, -6.8644],\n",
       "          [ 0.9637, -0.2670, -2.0267],\n",
       "          [-0.0462, -0.9989,  4.3101],\n",
       "          [-1.0000,  0.0091, -8.0000],\n",
       "          [-0.7424,  0.6699, -7.9256],\n",
       "          [-0.9954, -0.0960,  6.9699],\n",
       "          [-0.7893, -0.6140, -8.0000],\n",
       "          [ 0.9384,  0.3455,  0.6072],\n",
       "          [ 0.6454, -0.7639, -3.7283],\n",
       "          [-0.1005,  0.9949, -7.1238],\n",
       "          [-0.9683,  0.2497,  6.9810],\n",
       "          [-0.8951, -0.4459, -8.0000],\n",
       "          [-0.6286,  0.7777, -7.8350],\n",
       "          [-0.3770,  0.9262, -7.5733],\n",
       "          [ 0.9197, -0.3927, -4.6919],\n",
       "          [ 0.9203,  0.3913, -4.9488],\n",
       "          [ 0.2726,  0.9621, -4.6406],\n",
       "          [ 0.9749,  0.2228, -4.8500],\n",
       "          [ 0.9930,  0.1184, -5.0703],\n",
       "          [-0.9981, -0.0621, -8.0000],\n",
       "          [ 0.8315,  0.5555,  1.6290],\n",
       "          [ 0.9250,  0.3800, -2.2805],\n",
       "          [ 0.6862,  0.7274, -5.8337],\n",
       "          [ 0.7758, -0.6310, -3.1942],\n",
       "          [ 0.7964,  0.6048, -5.2489],\n",
       "          [-0.9571, -0.2898, -6.5144],\n",
       "          [ 0.8695,  0.4939, -5.1629],\n",
       "          [ 0.5866, -0.8098, -5.7006],\n",
       "          [ 0.2068, -0.9784, -7.1074],\n",
       "          [-0.3590, -0.9333, -8.0000],\n",
       "          [ 0.0615,  0.9981, -6.8648],\n",
       "          [-0.4530, -0.8915, -8.0000],\n",
       "          [ 0.8550,  0.5186, -5.3877],\n",
       "          [ 0.8928,  0.4504, -4.9587],\n",
       "          [ 0.5343, -0.8453, -6.4004],\n",
       "          [ 0.3003,  0.9538, -6.5754],\n",
       "          [-0.4879, -0.8729, -8.0000],\n",
       "          [-0.8134,  0.5817, -8.0000],\n",
       "          [ 0.6800,  0.7332,  2.5834],\n",
       "          [ 0.2899,  0.9570, -6.5362],\n",
       "          [-0.0943, -0.9955, -7.7450],\n",
       "          [ 0.9047,  0.4260,  1.0019],\n",
       "          [ 0.9980, -0.0628, -1.8744],\n",
       "          [ 0.9801,  0.1985, -1.8280],\n",
       "          [ 0.6989, -0.7152, -5.6983],\n",
       "          [ 0.8038,  0.5949, -5.2189],\n",
       "          [-0.9840, -0.1782, -8.0000],\n",
       "          [-0.8660,  0.5000, -8.0000],\n",
       "          [ 0.9860,  0.1670, -5.1216],\n",
       "          [ 0.9995,  0.0309, -1.6870],\n",
       "          [-0.1916,  0.9815, -7.3466],\n",
       "          [-0.1699, -0.9855, -3.7935]], device='cuda:0'),\n",
       "  tensor([[0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [1],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0]], device='cuda:0')))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Data:\n",
    "    def __init__(self):\n",
    "        #样本池\n",
    "        self.datas = []\n",
    "\n",
    "    #向样本池中添加N条数据,删除M条最古老的数据\n",
    "    def update_data(self, agent):\n",
    "        #初始化游戏\n",
    "        state = env.reset()\n",
    "\n",
    "        #玩到游戏结束为止\n",
    "        over = False\n",
    "        while not over:\n",
    "            #根据当前状态得到一个动作\n",
    "            action = agent.get_action(state)\n",
    "\n",
    "            #执行动作,得到反馈\n",
    "            next_state, reward, over, _ = env.step([action])\n",
    "\n",
    "            #记录数据样本\n",
    "            self.datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #更新游戏状态,开始下一个动作\n",
    "            state = next_state\n",
    "\n",
    "        #数据上限,超出时从最古老的开始删除\n",
    "        while len(self.datas) > 100000:\n",
    "            self.datas.pop(0)\n",
    "\n",
    "    #获取一批数据样本\n",
    "    def get_sample(self):\n",
    "        #从样本池中采样\n",
    "        samples = random.sample(self.datas, 64)\n",
    "\n",
    "        #[b, 3]\n",
    "        state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3).to(device='cuda')\n",
    "        #[b, 1]\n",
    "        action = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1).to(device='cuda')\n",
    "        #[b, 1]\n",
    "        reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1).to(device='cuda')\n",
    "        #[b, 3]\n",
    "        next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3).to(device='cuda')\n",
    "        #[b, 1]\n",
    "        over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1).to(device='cuda')\n",
    "\n",
    "        return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "data = Data()\n",
    "\n",
    "data.update_data(teacher), data.get_sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -1417.5552844437145\n",
      "10 -682.0980111896516\n",
      "20 -338.95231020186\n",
      "30 -177.04914194670872\n",
      "40 -205.4522163828523\n",
      "50 -91.80323189702658\n",
      "60 -181.1353294881537\n",
      "70 -145.59309594014502\n",
      "80 -156.91388656138562\n",
      "90 -146.0304440632086\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(100):\n",
    "    #更新N条数据\n",
    "    data.update_data(teacher)\n",
    "\n",
    "    #每次更新过数据后,学习N次\n",
    "    for i in range(200):\n",
    "        teacher.train(*data.get_sample())\n",
    "\n",
    "    if epoch % 10 == 0:\n",
    "        test_result = sum([teacher.test(play=False) for _ in range(10)]) / 10\n",
    "        print(epoch, test_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.348809152841568, -1636.895262798944)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class CQL(SAC):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def _get_loss_value(self, model_value, target, state, action, next_state):\n",
    "        #计算value\n",
    "        value = model_value(state, action)\n",
    "\n",
    "        #计算loss,value的目标是要贴近target\n",
    "        loss_value = self.loss_fn(value, target)\n",
    "        \"\"\"以上与SAC相同,以下是CQL的部分\"\"\"\n",
    "\n",
    "        #把state复制5遍\n",
    "        #[b, 3] -> [b, 1, 3] -> [b, 5, 3]\n",
    "        state = state.unsqueeze(dim=1)\n",
    "        #[b, 1, 3] -> [b, 5, 3] -> [b*5, 3]\n",
    "        state = state.repeat(1, 5, 1).reshape(-1, 3)\n",
    "\n",
    "        #把next_state复制5遍\n",
    "        #[b, 3] -> [b, 1, 3]\n",
    "        next_state = next_state.unsqueeze(1)\n",
    "        #[b, 1, 3] -> [b, 5, 3] -> [b*5, 3]\n",
    "        next_state = next_state.repeat(1, 5, 1).reshape(-1, 3)\n",
    "\n",
    "        #随机一批动作,数量是数据量的5倍,值域在-1到1之间\n",
    "        rand_action = torch.empty([len(state), 1]).uniform_(-1, 1).to(device='cuda')\n",
    "\n",
    "        #计算state的动作和熵\n",
    "        #[b*5, 3] -> [b*5, 1],[b*5, 1]\n",
    "        curr_action, curr_entropy = self.model_action(state)\n",
    "\n",
    "        #计算next_state的动作和熵\n",
    "        #[b*5, 3] -> [b*5, 1],[b*5, 1]\n",
    "        next_action, next_entropy = self.model_action(next_state)\n",
    "\n",
    "        #计算三份动作分别的value\n",
    "        #[b*5, 1],[b*5, 1] -> [b*5, 1] -> [b, 5, 1]\n",
    "        value_rand = model_value(state, rand_action).reshape(-1, 5, 1).to(device='cuda')\n",
    "        #[b*5, 1],[b*5, 1] -> [b*5, 1] -> [b, 5, 1]\n",
    "        value_curr = model_value(state, curr_action).reshape(-1, 5, 1).to(device='cuda')\n",
    "        #[b*5, 1],[b*5, 1] -> [b*5, 1] -> [b, 5, 1]\n",
    "        value_next = model_value(state, next_action).reshape(-1, 5, 1).to(device='cuda')\n",
    "\n",
    "        #[b*5, 1] -> [b, 5, 1]\n",
    "        curr_entropy = curr_entropy.detach().reshape(-1, 5, 1).to(device='cuda')\n",
    "        next_entropy = next_entropy.detach().reshape(-1, 5, 1).to(device='cuda')\n",
    "\n",
    "        #三份value分别减去他们的熵\n",
    "        #[b, 5, 1]\n",
    "        value_rand -= math.log(0.5)\n",
    "        #[b, 5, 1]\n",
    "        value_curr -= curr_entropy\n",
    "        #[b, 5, 1]\n",
    "        value_next -= next_entropy\n",
    "\n",
    "        #拼合三份value\n",
    "        #[b, 5+5+5, 1] -> [b, 15, 1]\n",
    "        value_cat = torch.cat([value_rand, value_curr, value_next], dim=1)\n",
    "\n",
    "        #等价t.logsumexp(dim=1), t.exp().sum(dim=1).log()\n",
    "        #[b, 15, 1] -> [b, 1] -> scala\n",
    "        loss_cat = torch.logsumexp(value_cat, dim=1).mean()\n",
    "\n",
    "        #在原本的loss上增加上这一部分\n",
    "        #scala\n",
    "        loss_value += 5.0 * (loss_cat - value.mean())\n",
    "        \"\"\"CQL算法和SCA算法的差异到此为止\"\"\"\n",
    "\n",
    "        return loss_value\n",
    "\n",
    "\n",
    "student = CQL()\n",
    "\n",
    "student.train(\n",
    "    torch.randn(5, 3).to(device='cuda'),\n",
    "    torch.randn(5, 1).to(device='cuda'),\n",
    "    torch.randn(5, 1).to(device='cuda'),\n",
    "    torch.randn(5, 3).to(device='cuda'),\n",
    "    torch.zeros(5, 1).long().to(device='cuda'),\n",
    ")\n",
    "\n",
    "student.get_action([1, 2, 3]), student.test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -1586.5367885704459\n",
      "2000 -440.8033422306197\n",
      "4000 -744.6367358641508\n",
      "6000 -811.5443056131819\n",
      "8000 -925.059320381306\n",
      "10000 -790.0178368523073\n",
      "12000 -469.19915360793584\n",
      "14000 -335.89593260487214\n",
      "16000 -374.0960422515992\n",
      "18000 -526.956089678878\n",
      "20000 -652.9847198925318\n",
      "22000 -944.5858477027953\n",
      "24000 -579.0335550935424\n",
      "26000 -437.14618826195436\n",
      "28000 -381.48333875498963\n",
      "30000 -323.94055882464437\n",
      "32000 -688.5410013490131\n",
      "34000 -742.9517021081399\n",
      "36000 -403.3001744538331\n",
      "38000 -391.963088438535\n",
      "40000 -542.3478956785744\n",
      "42000 -243.25677029129855\n",
      "44000 -231.09013437208418\n",
      "46000 -222.38982523130048\n",
      "48000 -369.4490839527075\n"
     ]
    }
   ],
   "source": [
    "#训练N次,训练过程中不需要更新数据\n",
    "for i in range(50000):\n",
    "    #采样一批数据\n",
    "    student.train(*data.get_sample())\n",
    "\n",
    "    if i % 2000 == 0:\n",
    "        test_result = sum([student.test(play=False) for _ in range(10)]) / 10\n",
    "        print(i, test_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student.test(play=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Gym",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
